import numpy as np

from toolkit.utils import interpolate_y, setmin
from toolkit.het_block import het


@het(exogenous='Pi', policy='a', backward='uc')
def household(uc_p, Pi_p, a_grid, e_grid, rpost, rsub, atw, tax, w, L, T, beta_grid, sigma):
    """Single backward iteration step using endogenous gridpoint method for households with separable utility.

    L has to be a scalar or specified on full grid.
    T has to be the same dimension as e_grid.
    atw has to be the same dimension as e_grid

    Parameters
    ----------
    uc_p : array
        Marginal utility from the following iteration, combined with Pi_p via
        a matrix product (assumed (n_e, n_a) -- TODO confirm against caller).
    Pi_p : array
        Matrix multiplied (after scaling by (1 + rsub) * beta_grid) into uc_p;
        presumably the exogenous transition matrix over e_grid.
    a_grid, e_grid : 1-D arrays
        Asset grid (columns) and productivity grid (rows).
    rpost : scalar
        Rate applied to current asset holdings in the budget constraint.
    rsub : scalar
        Rate used in the Euler equation.
    atw : array shaped like e_grid
        Labour-income multiplier per productivity state (presumably an
        after-tax wage -- confirm).
    tax, w : scalars or arrays shaped like e_grid
        Only used to build ``rev``.
    L : scalar or array on the full (e, a) grid
    T : array shaped like e_grid
        Lump-sum income added per productivity state.
    beta_grid : scalar or array broadcastable against Pi_p
    sigma : scalar
        Curvature; marginal utility is ``c ** (-sigma)``.

    Outputs
    -------
    uc, a, c, uce, rev : arrays on the (e, a) grid
        Marginal utility, asset policy, consumption, ``uc * atw * e`` and
        ``tax * w * e * L``.
    """
    # Labour income on the (e, a) grid; L may be a scalar or full-grid array.
    labor_inc = (atw[:, np.newaxis] * L) * e_grid[:, np.newaxis]

    # backward step
    uc_nextgrid = ((1 + rsub) * beta_grid * Pi_p) @ uc_p
    c_nextgrid = uc_nextgrid ** (-1 / sigma)
    lhs = c_nextgrid + a_grid[np.newaxis, :] - labor_inc - T[:, np.newaxis]
    rhs = (1 + rpost) * a_grid
    a = interpolate_y(lhs, rhs, a_grid)
    setmin(a, a_grid[0])
    c = rhs[np.newaxis, :] + labor_inc + T[:, np.newaxis] - a
    uc = c ** (-sigma)

    # other outputs
    uce = uc * (atw[:, np.newaxis] * e_grid[:, np.newaxis])
    # Put the e-dependent factor on axis 0 before multiplying by L, so a
    # full-grid L lines up with the (e, a) layout instead of broadcasting
    # e_grid along the asset axis. Copy so the output is a fresh writable array.
    rev = np.broadcast_to((tax * w * e_grid)[:, np.newaxis] * L, c.shape).copy()

    return uc, a, c, uce, rev


def mpcs(c, a, a_grid, r):
    """Estimate marginal propensities to consume from a consumption policy.

    The MPC is the slope of ``c`` with respect to cash on hand
    ``(1 + r) * a_grid``. Interior gridpoints use a centred secant between
    the two neighbouring points. The first and last gridpoints use one-sided
    secants. Wherever ``a`` equals the lowest gridpoint ``a_grid[0]`` the MPC
    is set to exactly 1.

    c : 2-D array whose last axis runs along a_grid.
    a : array with the same shape as c (asset policy).
    a_grid : 1-D asset grid.
    r : rate applied to a_grid to form cash on hand.

    The output has the dtype and shape of ``c``.
    """
    cash = (1 + r) * a_grid
    slopes = np.empty_like(c)

    # One-sided secants at the two ends of the grid.
    slopes[:, 0] = (c[:, 1] - c[:, 0]) / (cash[1] - cash[0])
    slopes[:, -1] = (c[:, -1] - c[:, -2]) / (cash[-1] - cash[-2])

    # Centred secants over each interior point's two neighbours.
    rise = c[:, 2:] - c[:, :-2]
    run = cash[2:] - cash[:-2]
    slopes[:, 1:-1] = rise / run

    # Agents sitting at the bottom of the grid spend every marginal unit.
    at_limit = a == a_grid[0]
    slopes[at_limit] = 1

    return slopes
